#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

import math
import random
import numpy as np
import torch
import scipy
import warnings
from ..layers import functional
from . import function_utils


#######################################################################
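# Not used for now. NOTE: this variant is NOT equivalent to reshape_input_4d: torch.reshape of the
# concatenated tiles reinterprets memory instead of placing tiles on a grid, and it hard-codes 3 channels.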
# def reshape_input_4d_fast(input, num_tiles_y, num_tiles_x):
#     # num_tiles_x = args.model_config.num_tiles_x
#     # num_tiles_y = args.model_config.num_tiles_y

#     if num_tiles_x == 1 and num_tiles_y == 1:
#         return input

#     num_sub_img = num_tiles_x * num_tiles_y
#     bs, n_col, tile_h, tile_w = input.shape
#     assert (bs % num_sub_img == 0)

#     combined_input_batch = torch.zeros([bs // num_sub_img, n_col, tile_h * num_tiles_y, tile_w * num_tiles_x],
#                                        dtype=torch.float32, device=input.device)
#     # combined_target_batch = []

#     # print(combined_input_batch.shape)
#     # combined_target = []
#     for new_idx, b_idx in enumerate(range(0, bs, num_sub_img)):
#         tmp = torch.cat([input[b_idx + local_idx, :, :, :] for local_idx in range(num_sub_img)], 1)
#         print(tmp.shape)
#         print(tmp.reshape([1, 3, tile_h * num_tiles_y, tile_w * num_tiles_x]).shape)
#         combined_input_batch[new_idx, :, :, :] = torch.reshape(tmp, (1, 3, tile_h * num_tiles_y, tile_w * num_tiles_x))
#     return combined_input_batch


# Tile each group of num_tiles_y*num_tiles_x consecutive images into one mosaic, going from shape
# [bs, n_col, h, w] to [bs/(num_tiles_y*num_tiles_x), n_col, h*num_tiles_y, w*num_tiles_x].
# Needed for multi-label training.
def reshape_input_4d(input, num_tiles_y, num_tiles_x):
    """Tile groups of consecutive images of a batch into one mosaic image each.

    Every ``num_tiles_y * num_tiles_x`` consecutive samples of the batch are laid
    out row-major on a ``num_tiles_y`` x ``num_tiles_x`` grid: sample ``k`` of a
    group lands in grid row ``k // num_tiles_x`` and grid column ``k % num_tiles_x``.

    Args:
        input: 4D tensor of shape ``[bs, n_col, tile_h, tile_w]``.
        num_tiles_y: number of tiles stacked vertically in each mosaic.
        num_tiles_x: number of tiles placed side by side horizontally in each mosaic.

    Returns:
        ``input`` itself (unchanged, not copied) when both tile counts are 1.
        Otherwise a newly allocated float32 tensor on ``input.device`` of shape
        ``[bs // (num_tiles_y * num_tiles_x), n_col, tile_h * num_tiles_y,
        tile_w * num_tiles_x]``.

    Raises:
        ValueError: if a tile count is smaller than 1, or if ``bs`` is not a
            multiple of ``num_tiles_y * num_tiles_x``.
    """
    if num_tiles_x == 1 and num_tiles_y == 1:
        return input

    if num_tiles_x < 1 or num_tiles_y < 1:
        raise ValueError('num_tiles_y and num_tiles_x must be >= 1, got {} and {}'.format(num_tiles_y, num_tiles_x))

    num_sub_img = num_tiles_x * num_tiles_y
    bs, n_col, tile_h, tile_w = input.shape
    # explicit check instead of assert, so it is not stripped under `python -O`
    if bs % num_sub_img != 0:
        raise ValueError('batch size {} is not a multiple of num_tiles_y*num_tiles_x = {}'.format(bs, num_sub_img))
    num_groups = bs // num_sub_img

    # The output is always float32 regardless of the input dtype (copy_ below converts).
    # torch.empty is enough: the tiles cover every output element exactly once.
    combined_input_batch = torch.empty([num_groups, n_col, tile_h * num_tiles_y, tile_w * num_tiles_x],
                                       dtype=torch.float32, device=input.device)

    # Split the batch axis as (group, tile_row, tile_col) and view the output as
    # (group, channel, tile_row, y, tile_col, x). A single strided copy then places
    # every tile, replacing the per-tile Python loop of slice assignments.
    # Sample index b_idx = group * num_sub_img + tile_row * num_tiles_x + tile_col (row-major).
    tiles = input.reshape(num_groups, num_tiles_y, num_tiles_x, n_col, tile_h, tile_w)
    output_view = combined_input_batch.view(num_groups, n_col, num_tiles_y, tile_h, num_tiles_x, tile_w)
    output_view.copy_(tiles.permute(0, 3, 1, 4, 2, 5))
    return combined_input_batch